{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "29317d99-1605-410d-b838-fbc3469e4856",
   "metadata": {},
   "source": [
    "# Local Tool Calling Agent\n",
    "\n",
    "Here, we'll build a [tool calling agent](https://python.langchain.com/v0.1/docs/modules/agents/agent_types/tool_calling/) using local models.\n",
    "\n",
    "We'll use the [new fine-tune from Groq](https://wow.groq.com/introducing-llama-3-groq-tool-use-models/) with tool calling via Ollama:\n",
    "\n",
    "Access the model:\n",
    "\n",
    "```\n",
    "ollama pull llama3-groq-tool-use\n",
    "ollama pull llama3.1\n",
    "```\n",
    "\n",
    "And also, we'll use the Ollama partner package.\n",
    "\n",
    "This notebook accompanies the video here:\n",
    "\n",
    "https://www.youtube.com/watch?v=Nfk99Fz8H9k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "666175b4-f056-489d-9d94-a0a666270794",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-xxx\"\n",
    "os.environ[\"TAVILY_API_KEY\"] = \"tvly-xxx\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "120c1da8-e45e-4ffa-9ac1-a536026c7e1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.2\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install -qU langchain-ollama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "32c0504b-007a-4af6-9976-c7294ed26b73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# /// LLM ///\n",
    "\n",
    "from langchain_ollama import ChatOllama\n",
    "\n",
    "llm = ChatOllama(\n",
    "    # model=\"llama3-groq-tool-use\",\n",
    "    model=\"llama3.1\",\n",
    "    temperature=0,\n",
    ")\n",
    "\n",
    "# /// Retriever tool ///\n",
    "\n",
    "from langchain_community.document_loaders import WebBaseLoader\n",
    "from langchain_community.vectorstores import SKLearnVectorStore\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "# List of URLs to load documents from\n",
    "urls = [\n",
    "    \"https://lilianweng.github.io/posts/2023-06-23-agent/\",\n",
    "    \"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/\",\n",
    "    \"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/\",\n",
    "]\n",
    "\n",
    "# Load documents from the URLs\n",
    "docs = [WebBaseLoader(url).load() for url in urls]\n",
    "docs_list = [item for sublist in docs for item in sublist]\n",
    "\n",
    "# Initialize a text splitter with specified chunk size and overlap\n",
    "text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
    "    chunk_size=250, chunk_overlap=0\n",
    ")\n",
    "\n",
    "# Split the documents into chunks\n",
    "doc_splits = text_splitter.split_documents(docs_list)\n",
    "\n",
    "# Add the document chunks to the \"vector store\" using OpenAIEmbeddings\n",
    "vectorstore = SKLearnVectorStore.from_documents(\n",
    "    documents=doc_splits,\n",
    "    embedding=OpenAIEmbeddings(),\n",
    ")\n",
    "retriever = vectorstore.as_retriever(k=4)\n",
    "\n",
    "\n",
    "# Define a tool, which we will connect to our agent\n",
    "def retrieve_documents(query: str) -> list:\n",
    "    \"\"\"Retrieve documents from the vector store based on the query.\"\"\"\n",
    "    return retriever.invoke(query)\n",
    "\n",
    "\n",
    "# /// Search Tool\n",
    "\n",
    "from langchain_community.tools.tavily_search import TavilySearchResults\n",
    "\n",
    "from langchain.schema import Document\n",
    "\n",
    "web_search_tool = TavilySearchResults()\n",
    "\n",
    "\n",
    "def web_search(query: str) -> str:\n",
    "    \"\"\"Run web search on the question.\"\"\"\n",
    "    web_results = web_search_tool.invoke({\"query\": query})\n",
    "    return [\n",
    "        Document(page_content=d[\"content\"], metadata={\"url\": d[\"url\"]})\n",
    "        for d in web_results\n",
    "    ]\n",
    "\n",
    "# Tool list\n",
    "tools = [retrieve_documents, web_search]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "30052f47-2b5d-46f5-9873-eb716145cda1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Annotated, List\n",
    "\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.runnables import Runnable, RunnableConfig\n",
    "from langgraph.graph.message import AnyMessage, add_messages\n",
    "from typing_extensions import TypedDict\n",
    "\n",
    "class State(TypedDict):\n",
    "    messages: Annotated[list[AnyMessage], add_messages]\n",
    "\n",
    "class Assistant:\n",
    "    def __init__(self, runnable: Runnable):\n",
    "        \"\"\"\n",
    "        Initialize the Assistant with a runnable object.\n",
    "\n",
    "        Args:\n",
    "            runnable (Runnable): The runnable instance to invoke.\n",
    "        \"\"\"\n",
    "        self.runnable = runnable\n",
    "\n",
    "    def __call__(self, state: State, config: RunnableConfig):\n",
    "        \"\"\"\n",
    "        Call method to invoke the LLM and handle its responses.\n",
    "        Re-prompt the assistant if the response is not a tool call or meaningful text.\n",
    "\n",
    "        Args:\n",
    "            state (State): The current state containing messages.\n",
    "            config (RunnableConfig): The configuration for the runnable.\n",
    "\n",
    "        Returns:\n",
    "            dict: The final state containing the updated messages.\n",
    "        \"\"\"\n",
    "        while True:\n",
    "            result = self.runnable.invoke(state)  # Invoke the LLM\n",
    "            if not result.tool_calls and (\n",
    "                not result.content\n",
    "                or isinstance(result.content, list)\n",
    "                and not result.content[0].get(\"text\")\n",
    "            ):\n",
    "                messages = state[\"messages\"] + [(\"user\", \"Respond with a real output.\")]\n",
    "                state = {**state, \"messages\": messages}\n",
    "            else:\n",
    "                break\n",
    "        return {\"messages\": result}\n",
    "\n",
    "\n",
    "# Create the primary assistant prompt template\n",
    "primary_assistant_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"You are a helpful assistant tasked with answering user questions. \"\n",
    "            \"You have access to two tools: retrieve_documents and web_search. \"\n",
    "            \"For any user questions about LLM agents, use the retrieve_documents tool to get information for a vectorstore. \"\n",
    "            \"For any other questions, such as questions about current events, use the web_search tool to get information from the web. \",\n",
    "        ),\n",
    "        (\"placeholder\", \"{messages}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Prompt our LLM and bind tools\n",
    "assistant_runnable = primary_assistant_prompt | llm.bind_tools(tools)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "40504a0b-8a99-4420-a6bf-561c62e893d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/4gHYSUNDX1BST0ZJTEUAAQEAAAHIAAAAAAQwAABtbnRyUkdCIFhZWiAH4AABAAEAAAAAAABhY3NwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAA9tYAAQAAAADTLQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAlkZXNjAAAA8AAAACRyWFlaAAABFAAAABRnWFlaAAABKAAAABRiWFlaAAABPAAAABR3dHB0AAABUAAAABRyVFJDAAABZAAAAChnVFJDAAABZAAAAChiVFJDAAABZAAAAChjcHJ0AAABjAAAADxtbHVjAAAAAAAAAAEAAAAMZW5VUwAAAAgAAAAcAHMAUgBHAEJYWVogAAAAAAAAb6IAADj1AAADkFhZWiAAAAAAAABimQAAt4UAABjaWFlaIAAAAAAAACSgAAAPhAAAts9YWVogAAAAAAAA9tYAAQAAAADTLXBhcmEAAAAAAAQAAAACZmYAAPKnAAANWQAAE9AAAApbAAAAAAAAAABtbHVjAAAAAAAAAAEAAAAMZW5VUwAAACAAAAAcAEcAbwBvAGcAbABlACAASQBuAGMALgAgADIAMAAxADb/2wBDAAMCAgMCAgMDAwMEAwMEBQgFBQQEBQoHBwYIDAoMDAsKCwsNDhIQDQ4RDgsLEBYQERMUFRUVDA8XGBYUGBIUFRT/2wBDAQMEBAUEBQkFBQkUDQsNFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBQUFBT/wAARCADbAMcDASIAAhEBAxEB/8QAHQABAAICAwEBAAAAAAAAAAAAAAUGBwgBAwQJAv/EAFcQAAEDBAADAgcIDAkJCQEAAAECAwQABQYRBxIhEzEIFBYiQVGUFRcjVVZh0dMyNkJUcXSBkZOVtNIJNThSU3WSstQkM2JjcnOhs8EYNEVXgoOEscPx/8QAGgEBAQADAQEAAAAAAAAAAAAAAAECAwQFB//EADURAQABAwAFCgMIAwAAAAAAAAABAgMRBBIhMVETFEFSYXGRobHBFSPRBSIyM1OB4fBCQ8L/2gAMAwEAAhEDEQA/APqnSlKBSlKBSvJdLnHs1vfmylFLDKeY8qSpSj3BKUjqpROgEjqSQB1NQfk9Lyb4e/OOsxVbLdnjulCEJ9HbKSduL9YB5BvQCtc6ttNETGtVOI/u5cJmTfbbCcKJFwisLHQpdfSkj8hNdPlVZfjiB7Sj6a6o+F4/EbCGLFbWkAAaREbHd0Hort8lbL8TwPZkfRWfye3yNh5VWX44ge0o+mnlVZfjiB7Sj6aeStl+J4HsyPop5K2X4ngezI+inye3yXYeVVl+OIHtKPpp5VWX44ge0o+mnkrZfieB7Mj6KeStl+J4HsyPop8nt8jYeVVl+OIHtKPprlGTWdxQSi7QVKPoTJQT/wDdceStl+J4HsyPorheJ2NxBSqzW9ST0IMVBB/4U+T2+RsSiVBaQpJCkkbBB2CK5qsLwKDBWp+wKVjssnm/yIajrP8ArGPsFA+kgBXfpQJ3UjY7y5PW/DmMeKXOLoPMg7QsHucbPpQrR0e8EEHqKxqojGtROY8JTHBLUpStKFKUoFKUoFKUoFKUoFKUoFKUoKvdtXbOLTbl6VGgsLuTiD907zBtn8IG3VdfSEHvGxaKrDo8T4ksOL2ET7WppCtdOZl3m1v1kPEj/ZPqqz10Xd1ERux9c+ayUpSudFAhceMHuWUXLHYd4cmXa3KfRIajQJLiA4ykqdbS6lsoW4kA7QlRVsa1vpVZ4U+E9jfEPhnMzC4NS7AxAK1TUPwJXZtI7dxprkcUykPKIQNhvmKSrRAPSqjhwvGOeEAYOF2TLbZityudwkZNBvluKLU25yqUmZCkK9LroSezQpQIWSUoIquYvc86w7wd7hhFnx3J7VllinuplzI1rUrtITlzUp12A4oFt93xdwqSkbOwemwKDOVq8ILAbziGQZPFv27Rj6Su6qdhyGn4aeXm2thbYdGx1HmddHW9VVM78LHFMYtNjuNrbn3yHcb3GtSpLNrm9kG3DtbzSgwQ/pPVIbJ5yfNJ1qsG3bDbxLsvH1NmxvO5MPIcQiItb2RsSpEue8yZCXEjtOZxKtup5WlBKtbKU8vWs7cfrDcU8PcHm2myzLonGshtN1k262sFyT4swsBwNNDqtSQd8o69DQZfs92j320w7lE7bxWWyl9rxhhbDnKobHM24ErQdHqlQBHcQK9lRuOXxvJbJEubUSbAbkp50x7lGXGkIGyNLbWApJ6b0R6RUlQKrGXatdzsN5RpK25iIDx6+ezIUGwn9KWVfkPrqz1WM8T43Fs9vSCXZd1iFIA30ZdEhRPqHKyrr84rosfmRE7unu6fJY3rPSlK50KUpQKUpQKUpQKUpQKUpQKUpQRWRWZV4iNFhxLFwiOiTDfWCQ26AR1AIJSpKlIUAeqVqAI7667XfI18D9vlNCNcUJKZNueOzy9xUnYHO2d9FgaPcdEFImajrzj1uyFptu4RG5PZEqacO0uNKI0VIWNKQddNpINbqaqZjVr3en9/vbe9SB4NnCdJBHDfFgR3EWhj92uP+zXwn/8ALbFf1Qx+7VhODFvpHyK+x0dAEeOB3Q/C4lSj+U7p5EyPlVfv0zP1VZalvr+UmI4rJHjtRI7TDLaWmWkhCG0DSUpA0AB6ABXZVX8iZHyqv36Zn6qnkTI+VV+/TM/VU5O31/KTEcVopWvvgtXrIeMfBe05VfsouqLnKky2nBDU023ytSXG06BbJ+xQN9e+steRMj5VX79Mz9VTk7fX8pMRxeDIuB3DzLrzIu17wiwXe6SeXtpk23NOuucqQlPMpSSTpKQPwAVHq8G/hStKArhxi6ggcqQbSweUbJ0PN9ZJ/LU/5EyPlVfv0zP1VBhLxBCsnvy0nprt2h/xDYNOTt9fykxHF3Wy04vwtx0RbdCt2NWZtZUmPEaSw12ij3JQkDalH0AbJ7tmubPCkXW7C+z2DGKWlMwYq/s2m1EFS1j0LVyp6fcgAd5VXZa8LtVqmiaGnZlwAIEyc+uQ6nfeEqWTyA+pOh81TtSaqaImLfT0/Q2RuKUpWhClKUClKUClKUClKUClKUClKUClKUClKUClKUGu/gB/yYce/Hbj+2vVsRWu/gB/yYce/Hbj+2vVsRQKUpQKUpQKUpQKUpQKUpQKUpQKUpQKUpQKUpQKUpQKUpQKUpQa7+AH/Jhx78duP7a9WxFa7+AH/Jhx78duP7a9WxFApSlApSlApSlApSlApSlApSlApSlApSlApSlApSq3e8oksXFdttENqbNaQlchyQ8WmWAr7EEhKipZAJ5QOgAJKeZO9lFuq5OKV3rJSqR7u5h94WP2t76unu7mH3hY/a3vq66Oa18Y8YMLvWjf8J/wLVlOE2ziTbI5cuNgAh3HkGyqEtZKFf8AtuKPd6HVE9E1tb7u5h94WP2t76uo/IU5FlVhuNlulnsMu23CO5Ekx1y3tONLSUqSfg/SCac1r4x4wYfMf+Dy4KOcU+O8K9yW1CyYkpu6vuDYCpIVuM3sdx508/qIaUPTX1/rXXwdODF08HDBXccszFpuCpEtyZJnyJDiXHlK0EggN6ASgJTodN7PTmNZT93cw+8LH7W99XTmtfGPGDC70qke7uYfeFj9re+rp7u5h94WP2t76unNa+MeMGF3pVLRmF4tI7e+W6Ei3D/OyYElbimB/PUhSBtA6bIOx360CaulaLlqq3+IxgpSlakKUpQKUpQKUpQKUpQKUpQKUpQKoVoPNleZ79FyaA6ejxGMf+pq+1QbP9teaf1m1+wxa7tF/wBnd/1DKN0pulYE4/Z1kdsyI2vDcjvLN5h2pVwetFmskWYlKeZYQ7JekKSENqKSkIQQs8qiN1W8j4wZnPs/D3KZN7fwbCbxjkefNvVutDdwZYuLvIezkhYUpqPyq6LGupPMsVlNUQxbP15LfeIF3MoQZsaaYj6o0gR3Uudi8kAqbXonlUNjaT1GxWEPLrKIPHpy25JksjHMdmS2mseiItTTtuvDSmQSgzNFSJBXz6QVJ2EjlCt1SIWT8QrRh2R+5NxlSEQM9mwr1e7RY4jtxRCQ0nTwjIaSh1fPyBauRS+XuB10aw2vU4hCkJUpKVLOkgnqTrfSv1WrmQtXDPeKXAy42XiRLktTbRdi1ebdAiJDpQlkrcDbjSwlSwQhSSPN7PoEnm3xxv40ZVh+QZLeMUv11u9sxqRGbuFrYscU2uOT2XaMPS1rS8pwpXzfA75OdIUO801htJX5LiA4lBUkLUCQknqQNbOvyj89a/ZZkue3jOOLUOzZibBBxK3xJsGO3bY7/auuRVuFDi3Ek9mS31A0rzuigBqoK1u3viRx44ZZGxks2wLuuAG6uRoceM4hIU9DW4yO1bWeVZWNnfMOQcpGztrDYbNQDht+BAI8QkdCNg/Bqq5WklVqhEkkllBJP+yKp2a/abfvxB//AJaquFn/AIphf7hH90Vb/wCTT3z6QvQ9lKUrzkKUpQKUpQKUpQKUpQKUpQKUpQKoNn+2vNP6za/YYtX6qPdYsrG8guM9MKROt9zWh5a4bRccYdS2hohSAOYpKUJIUN6IIIHTfbosxmqnpmPeJ9mUdKo5jwStWYZS9fjeL5ZZMyEi3XFm0TAw3cI6VKKUO+aVDXaLAU2pCtKI3UFP8Ge0T8QteLeVmWR8ehW4Wly3R7g2hqZFCiQ28Oy/mq5OZHKopABJ76u134lWewWyVcrozdrdb4rZdfly7TKaaaQO9SlqbASB6ya67FxSseUWmNdbMi6XW2SU87EyFapLzLo2RtK0tkEbBHQ+iurkK+rJqzwQU7gLabpllvvEy+5BKgW+axcImPuzUm3MSGUBLS0I5OcBOgQnn5ebqQTXL/A2GiFc2bXlOS4+7cL2/fnpNrmNNuds6kJW3otKSprQ2EqSog9d7A1a/LON8WX79SS/qqeWcYf+GX79SS/qqchX1ZNWeCkveDjjjeN4larVcr3YHsXU+q3XS2y0iWC/vxjnUtC0r7QqKlbT392q8WTeC/juVIyGNJv2SxrTf1iRcbVEnpbjvyeRKPGTpvm5/g0KI5uQqSCUHuq5Y/xXx/LLW3c7Gbjeba6VJRMt9skvsrKVFKgFobIJCgQevQgipHyzjfFl+/Ukv6qpyFfVNWeCJi8KLZHuGYTlzrhIlZTDjwp7jq2+iWmVMpUgBAAUUrJO9jfcAOlQk7wf7M/Bw1uBfL9Yp2K20WmFc7ZJbbkOxeRtBbe5m1IUD2SFHSRojY1Vx8s43xZfv1JL+qqGyfjPi2FR4z+QyZtiYkuhhl25W6RHS64e5CStABUfUOtXkK+rJqzwT+a/abfvxB//AJaquFn/AIphf7hH90VRJ8qTmFvk2m3224sGa0phyXOhLjNx0KBSpfwqQVEDekgHZI3pO1DIjLKY7LbSBpCEhKR6gOgrRpH3bdNE78z7E7IfulKV57EpSlApSlApSlApSlApSlApSvytaW0KUpQSlI2VE6AFB+qpXFjilD4SY5Husq0Xi+uSpjUCNAskNUmQ685vlGh0A6HqSPQBskA+e6cQr0c9xK02DFXr/jN3jOTJmUsS2xEiNBPwYT1JcUtRQRrXmnaebSuXv4VcK4XCe0XKFFu94vjtxnu3GTNvUxUl5bi9DQJ6ABKUjoOutnZoOiBgN7f4i5JfLzlT94xW5wUQYmJPRGxFjp0O1U5sbcUo8w6681ZB5tJ1eWGG4zLbLLaWmm0hCG0JCUpSBoAAdwFdlKBWuPh4cc/eW4Gz2oEjscjyLmtlv5TpbaVD4Z4ekcqDoEdyloNbHVgbwg/A4xDwk8ktt5ye+ZJDct8TxRiJa5TDbCRzqWpfK4ys86uYAkEbCE9OlBqb/Bd8dfcfI7pwvukgJiXTmuNqLivsZKUjtWh/ttpCgO4dkr0qr6V18+vAn8DDDMnxDDeKki8ZFHyKDdnZLceNKYTFUY0taUJUkslZSoNgKHON7VrW6+gtArplQ2JzXZSWG5DfMlfI6gKHMkhSTo+kEAg+ggV3UoMdysRv+IZNmWZWu93nJxOt/NFwyS+0mKmW2gBPYOKA7ILCUpI3ralKPMdamcIzsZPjdhm3i2P4jebq2tSbDd3EJlpWgnnSEg+eBrm2OvKUkhO9C11Vcy4W4rxAuuP3PILLHuVxsEtM22SnNhyM6Ck7SQRsEpSSk7B5RsdBQWqlYqdy3KuFTHEHI+Ic+33DCoLqZlods8J0zWY6iQpp5sbCuTzNKG97UokDonIeN5Fbsux+3Xu0SPG7XcY6JUWQEKR2jS0hSVaUARsEHqBQSVKUoFKUoFKUoFKUoFKUoIjKsusuD2R28ZBc41ntbS223JktwIbQpa0oRtR6DalJG/nqmXDH8k4k3POcXzOyW2Pw6lxkQoC4c53x6ZzJ26tZTyhtPUJCehBQfskkGrHxOsdoyLAL7Cv1lTkdq8WU+9alDfjXZfCpQOo6lSE6+fVccMMyb4g8PbBkbVtkWdu4xEPiBLSQ7H2NFCtgdxGt6699BK4zjNrw3H7fY7JCat1pgMpYjRWRpLaB3Aek/hPUnqak6UoFKUoFdMuWxAivSpTzcaMwhTjrzywlDaANlSiegAAJJNR+VZXZ8Hx6dfb9cWLVaILZdkS5KuVCE/8AUk6AA6kkAAk1rExByfw3Z7cq5Nz8R4EsuBbEAkszsnIOwtzXVuNsAgDqrvGzooC0+AAoL8F7HFpIUlUy4lKgdgjx17qK2KrxWWyW/G7TEtdqhMW62xG0sx4sZsIbaQBoJSkdAK9tApSlApSlBwRsaPUVTL5w1N1z7GcniZHeLQmzNOR3LRDkagzWVJOkutEa2lXKQoddJ16iLpSgpPDvNr9kaLs1lWKuYdNi3F2JFQ/MafbnMjzm3WlJOztJTsEdDsddEC7Vini/Fwl/P+Fq8omTY16avDirA3FBLb0nsjzJd0k6Ty+sjr6aytQKUpQKUpQKUpQKUr8rcQ2NrUEj/SOqDEnhDeEvj3g1Wyz3HJbLf7lAubrjCJNmitutsuJCVBDqnHEBKlgqKQNkhtf82tK4X8JznlymLsOPY5bbtdJ19W3bLheUlO4biylhhcdlSdOjaNrDqh3jR6Kr6D8SMExzivhd0xbI2WptquDRbcSVDmbV9y4gn7FaTog+givmhwj8FC88MfDixTGL0343ZYMpd6h3ZKfgpMdhKnGl9/RXaJbSpJO0k+kEE3Ej6s0rq8aZ/pm/7Qp40z/TN/2hTEjtqn8VeLGM8GMOl5LlVxTAt7HmoQPOdkOEea00jvWs67vwkkAEiC448fce4GY0xOuCXbteLg54taLFbxzyrjI6ANtpG9Dak7VrpsdCSlJxzwq4CZFnuYROKXGtTM/JmvPsmLNnmgWBBOx5vULf7tqO9EA7JCSmCJxXhdlPhS5DBzji3AcsuDxHBIx/h84o/CfzZM8fdKIPRs929EAcwXtO22hltDbaEobQAlKUjQAHcAK/VKBSlKBSlKBSlfhbqG9c60p33cx1QfuvJdn5cW1TXrfFROntsrXHiuvdil5wJJSgr5VcgJ0ObR1vej3V3eNM/wBM3/aFPGmf6Zv+0KuJHzoyD+FJQ5eIouXBeL4/apCykTrwFvRnRtKuQmKC2vvBPf6K298F3j3J8I7hs9lz+MLxVr3QdhsR1zPGg+hCEEupX2bfTmUtGtHq2evoGjfh0eC3Pe8I2xTcTjpci5/KDZCB8HHn7AeUsgealSSHST/rT3Jr6M8NcKs3C3ArFidnU2i32mKiM2dgFwjqpxWvulqKlH51GmJFqpXV40z/AEzf9oVyJDSiAHUEnuAUKYkdlKUqBSlKDy3Sb7m2yXL5ebsGVu8vr5Uk/wDSseWvErVfrdEuV5t8S8XKUyh56TOYS8ragCUp5h5qB3BI0ND17NXnKvtYvH4m9/cNV7GvtctX4o1/cFelo8zRbmqmcTlluh4ve+xb5NWf2Br92nvfYt8mrP7A1+7VF4V+EVYuJIykuNSbMixzJiFvTYclljxVhYT2y3nWkIQo75i0TzoG9joTVgwjjbhXEWe/CsN7EqW1H8bLL8Z6MpbG9ds32qE9o3sgc6Np6jr1FbYv3J/znxTM8U1732LfJqz+wNfu0977Fvk1Z/YGv3agMS48YJnV/RZrJkDc2e6lxcdJjvNNykt/Zlh1aAh4J9JbUrp17qrWD+EPa18HsTy7Npce1zr4XG241uivvF1xK3BpplAccOko2e/XedU5xc68+JmeLIZ4fYz0Ldgt0dwdUvRoyGXEH1pWgBST84IIqxYJdJF0sBMp0yJEaTIhqeOtuBp1SEqOgBzFKQToAb3rpXgsl5h5HZ4V1tz3jECayiQw9ylPO2obSrSgCNgjvFfrhn/Elw/rad+0LrC9VNyzM1TnEx7rnMbVupSleWxKUpQK8t0ukWy2+ROmvJjxGEFbjiu4AfMOpPqA6k9BXqrEHHW8uOzrNY0K0wUrnSE7+yKSEtD5xsrV+FCa7ND0edKv02uPosK5lXEW85Y+4lmRIs9q2Q3Fjr7N5xPoLjifOBP81JAG9Hm1uqaqw21xaluQI7ritcy3WgtSvwk9TXupX0ezao0enUtRiGOtKP8AJ61fFkP2dH0U8nrV8WQ/Z0fRUhVQvPFzEsfvLlrn3hDEppSUPHsXFNMKVrlS66lJQ2TsdFKHeK2VXYojNVWP3MzxT/k9aviyH7Oj6KeT1q+LIfs6Poqu3zjDiOOXOdb7hdizLgKQJaERXnBHCkJWlTikoISgpWnzyQnvG9ggevKOJmNYc/DZut0Sy/LQXWWmWnH1qbHe5ytpUQj/AEjofPWPL0Rn7+7ftMzxS/k9aviyH7Oj6KHHbUQR7mQ9Hp/3dH0VBcJ8ul55w7sl/nNsNSpzJccRGSUtg8yh5oJJ7gO8mrbWVFzXpiqJ2SZni77Jcbhi7iV2ae/bwkj4BKiphQ9RaPm/lAB9RFZx4fZ8zmcNbbyExbtHA8YjJO0kHoHEE96Tr8IPQ+gnA9eux3hzG8ltN1bVyhqQhl7r9kw4oIcB9ethWvWgV5Wn6DRpVuaoj78bp9pWJzsls3SlK+eiLyr7WLx+Jvf3DVexr7XLV+KNf3BVkyNlcjHro02kqcXFdSlI9JKCBVaxdaXMatKknaVRGSD6xyCvQs/kz3+y9DWa6YnkV44fcauGrWP3di93e73S7W6YuItNvmMuupebQJP2AUsbbKSQQd70KkMut978IHKbT7iYxfMPjWrG7zCkTL7BVB5X5kZLLUdoHq4EKHOVJBQOROiSa2cpTVRrDjyL3m7vBbHI+FXzGZGGSGZV4m3KCY8aOliG5HUww6fNeDiljRbJHKNnVQ2P2BVp4H4fa79jWdWfK8VuMyNDuuOWtUiRDf2s9shI5g9HdQ6Ek8qkq6g61sbb0pqio8JLjk124a47MzKImDk70RCp7CUhPK586QSEqI0SkdxJHoqx8M/4kuH9bTv2hdeuvNw1QU2GYv7ly6TlJOu8eMuDf/A//wArKvZYq74916FspSleahSlKBWEON0VUfNbVKV/m5UBbKTr7ptzmI/M6PzH1Vm+qzxAw5OaWExULSzOYWH4jy96Q4ARpWvuVAlJ+Y77wK9L7P0inRtJprr3bp/dYa/0pLjOR5Ei3z4yo8praH4rw6j0f+pJ9BHQiqaODGBA7GG2MH+r2v3a+hTVVMRNGJjv/iWC5VrlEwtm3XTKLDk9jzO5e6l3kvtO2eXL9z5caQvYLgbcS2ggKIWFgdE+mste8vgPyMsX6va/dq4ssojtIaaQlttCQlKEjQSB0AFaK7M3sa8RGP39YGHHsXmse/XHatsosTILLMEFlavGQm2pb02SPhDzDl6b69O+vBiarnw8yxm53PHbzdI92x22RWX4EJT7kR1hCg4w4kdW+YrCtnQ2Ds9OmdKVObRmKonExmfGZn3FA4CW2ZaOEGMw58R+BMajqDkaS2W3Gz2ijpST1B61f6rt+4dYtlE7x28Y7bLpL5A328uKhxfKO4bI3rqajveWwH5GWL9Xtfu1soprt0xRTETEbN/8C511PxVXFyJBb6uy5TMdA1vqpxI3+QbP5KjrFjNkw2E8zaLbCs0Ra+1cRFaSygq0BzHQA3oAb+asu8JcEffnsZJcWVMstJV4hHcSQslQ5S8oHu83YSPUpR9IrXpOkxotmble/o71p35ZfpSlfM1Kqcrh8nt3F2y93KxsrUVmLDDC2Qo9SUpdaXy7PXSSBsk661bKVsouVW/wyucKb5AXD5Z3v9BC/wAPTyAuHyzvf6CF/h6uVK3c5udnhH0Mqb5AXD5Z3v8AQQv8PTyAuHyzvf6CF/h6uVKc5udnhH0Mqgjh/IX5srKr1KZP2TX+TM8w9I52mUrH4UqB9RFWmHDYt0RmLFZRHjMoDbbTSQlKEgaAAHcK7qVrru13NlU+3oZyUpStKFKUoFKUoIXJMNs2XNIRdYKJC2wQ28CUOt77+VxJCk/kPWqU9wDtalks329R0HuQFsLA/AVNE/nJrJ9K7LWmaRYjVt1zELliz3gYPylvf5ov1FPeBg/KW9/mi/UVlOlb/iel/qen0MsWe8DB+Ut7/NF+op7wMH5S3v8ANF+orKdKfE9L/U9PoZYs94GD8pb3+aL9RXI4AwN9ckvZH/xR/wDhWUqU+J6X+p6GVKsHCDHLDIbkqYeuktshSHri52vKR3EI0EA/OEg1daUriu3rl6rWuVTM9pkpSlaUf//Z",
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from langchain_core.messages import ToolMessage\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from langgraph.prebuilt import ToolNode\n",
    "\n",
    "\n",
    "def create_tool_node_with_fallback(tools: list) -> dict:\n",
    "    return ToolNode(tools).with_fallbacks(\n",
    "        [RunnableLambda(handle_tool_error)], exception_key=\"error\"\n",
    "    )\n",
    "\n",
    "\n",
    "def handle_tool_error(state: State) -> dict:\n",
    "    error = state.get(\"error\")\n",
    "    tool_calls = state[\"messages\"][-1].tool_calls\n",
    "    return {\n",
    "        \"messages\": [\n",
    "            ToolMessage(\n",
    "                content=f\"Error: {repr(error)}\\n please fix your mistakes.\",\n",
    "                tool_call_id=tc[\"id\"],\n",
    "            )\n",
    "            for tc in tool_calls\n",
    "        ]\n",
    "    }\n",
    "\n",
    "\n",
    "from IPython.display import Image, display\n",
    "from langgraph.checkpoint.sqlite import SqliteSaver\n",
    "from langgraph.graph import END, START, StateGraph\n",
    "from langgraph.prebuilt import tools_condition\n",
    "\n",
    "# Graph\n",
    "builder = StateGraph(State)\n",
    "\n",
    "# Define nodes: these do the work\n",
    "builder.add_node(\"assistant\", Assistant(assistant_runnable))\n",
    "builder.add_node(\"tools\", create_tool_node_with_fallback(tools))\n",
    "\n",
    "# Define edges: these determine how the control flow moves\n",
    "builder.add_edge(START, \"assistant\")\n",
    "builder.add_conditional_edges(\n",
    "    \"assistant\",\n",
    "    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools\n",
    "    # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END\n",
    "    tools_condition,\n",
    ")\n",
    "builder.add_edge(\"tools\", \"assistant\")\n",
    "\n",
    "# The checkpointer lets the graph persist its state\n",
    "memory = SqliteSaver.from_conn_string(\":memory:\")\n",
    "react_graph = builder.compile(checkpointer=memory)\n",
    "\n",
    "# Show\n",
    "display(Image(react_graph.get_graph(xray=True).draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "43c633d5-e7a7-4b7c-8dc7-760a3b032e95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "\n",
    "def predict_react_agent_answer(example: dict):\n",
    "    \"\"\"Use this for answer evaluation\"\"\"\n",
    "\n",
    "    config = {\"configurable\": {\"thread_id\": str(uuid.uuid4())}}\n",
    "    messages = react_graph.invoke({\"messages\": (\"user\", example[\"input\"])}, config)\n",
    "    return {\"response\": messages[\"messages\"][-1].content, \"messages\": messages}\n",
    "\n",
    "\n",
    "example = {\"input\": \"Get me information about the the types of LLM agent memory?\"}\n",
    "response = predict_react_agent_answer(example)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf82fa52-9e6c-4f37-94ae-91450dac602e",
   "metadata": {},
   "source": [
    "See trace with llama3.1 here:\n",
    "\n",
    "https://smith.langchain.com/public/44d0c7dd-a756-47ad-8025-ee7ae6469ecb/r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cd74a0b3-be40-46cd-97bf-ef9676878289",
   "metadata": {},
   "outputs": [],
   "source": [
    "example = {\"input\": \"Get me information about the current weather in SF.\"}\n",
    "response = predict_react_agent_answer(example)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cac91bf-c975-44a2-a9fd-99706fee5735",
   "metadata": {},
   "source": [
    "See trace with llama3.1 here:\n",
    "\n",
    "https://smith.langchain.com/public/7a4938e3-f94f-4e04-a162-bf592fba4643/r"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74b813cb-18ed-42d8-b313-6ee56ded4bcc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
